"""Immutable function vars."""

from __future__ import annotations

import dataclasses
from collections.abc import Callable, Sequence
from typing import Any, Concatenate, Generic, ParamSpec, Protocol, TypeVar, overload

from reflex.utils import format
from reflex.utils.types import GenericType

from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock

# Generic placeholders used by the overloads in this module:
# - P captures (the remainder of) a callable's parameter list,
# - V1..V6 stand for up to six individually typed leading positional arguments,
# - R is the callable's return type.
P = ParamSpec("P")
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
V6 = TypeVar("V6")
R = TypeVar("R")


class ReflexCallable(Protocol[P, R]):
    """Protocol for a callable.

    Parameterized by ``P`` (the parameter specification) and ``R`` (the return
    type), e.g. ``ReflexCallable[[Any], str]``.
    """

    # Declared as an attribute typed ``Callable[P, R]`` rather than as a method.
    __call__: Callable[P, R]


# Covariant type variables bound to ReflexCallable. CALLABLE_TYPE parameterizes
# the function var classes below; OTHER_CALLABLE_TYPE is a second, independent
# variable used by FunctionStringVar.create for the type of the var it returns.
CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, covariant=True)
OTHER_CALLABLE_TYPE = TypeVar(
    "OTHER_CALLABLE_TYPE", bound=ReflexCallable, covariant=True
)


class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
    """Base class for immutable function vars.

    A function var can be partially applied with :meth:`partial` or invoked
    with :meth:`call`; calling the instance itself is an alias for
    :meth:`call`.
    """

    # The ``partial`` overloads go from most to least specific: no arguments,
    # then one to six typed leading arguments (removed from the resulting
    # signature via ``Concatenate``), then untyped fallbacks. The order matters
    # for overload resolution.
    @overload
    def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
        arg1: V1 | Var[V1],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
        arg5: V5 | Var[V5],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
        arg5: V5 | Var[V5],
        arg6: V6 | Var[V6],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(self, *args: Var | Any) -> FunctionVar: ...

    def partial(self, *args: Var | Any) -> FunctionVar:  # pyright: ignore [reportInconsistentOverload]
        """Partially apply the function with the given arguments.

        Args:
            *args: The arguments to partially apply the function with.

        Returns:
            The partially applied function. When no arguments are given, the
            function var itself is returned unchanged.
        """
        if not args:
            # Binding nothing leaves the function as it is. Wrapping it in a
            # zero-parameter arrow function would instead produce a function
            # that *returns* this function when invoked.
            return self
        # "...args" is used as the only parameter name so the generated arrow
        # function collects the call-time arguments and spreads them after the
        # pre-bound ones.
        return ArgsFunctionOperation.create(
            ("...args",),
            VarOperationCall.create(self, *args, Var(_js_expr="...args")),
        )

    # The ``call`` overloads mirror ``partial``: one to six typed arguments,
    # then untyped fallbacks.
    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1], R]], arg1: V1 | Var[V1]
    ) -> VarOperationCall[[V1], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
    ) -> VarOperationCall[[V1, V2], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
    ) -> VarOperationCall[[V1, V2, V3], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
    ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
        arg5: V5 | Var[V5],
    ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
        arg1: V1 | Var[V1],
        arg2: V2 | Var[V2],
        arg3: V3 | Var[V3],
        arg4: V4 | Var[V4],
        arg5: V5 | Var[V5],
        arg6: V6 | Var[V6],
    ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
    ) -> VarOperationCall[P, R]: ...

    @overload
    def call(self, *args: Var | Any) -> Var: ...

    def call(self, *args: Var | Any) -> Var:  # pyright: ignore [reportInconsistentOverload]
        """Call the function with the given arguments.

        Args:
            *args: The arguments to call the function with.

        Returns:
            The function call operation, after ``guess_type()`` has been
            applied to it.
        """
        return VarOperationCall.create(self, *args).guess_type()

    # Invoking a function var directly is the same as ``call``.
    __call__ = call


class BuilderFunctionVar(
    FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
):
    """Base class for immutable function vars with the builder pattern.

    Calling an instance partially applies the given arguments (it is bound to
    ``FunctionVar.partial``) instead of producing a function call as
    ``FunctionVar.__call__`` does.
    """

    # Rebind ``__call__`` to ``partial`` so calling the var accumulates
    # arguments rather than invoking the function.
    __call__ = FunctionVar.partial


class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
    """Immutable function var whose JavaScript expression is a raw string."""

    @classmethod
    def create(
        cls,
        func: str,
        _var_type: type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
        _var_data: VarData | None = None,
    ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
        """Build a function var out of a JavaScript expression string.

        Args:
            func: The JavaScript expression that evaluates to the function.
            _var_type: The callable type assigned to the Var.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            A ``FunctionStringVar`` whose expression is ``func``.
        """
        # The string is used verbatim as the var's JavaScript expression.
        function_var = FunctionStringVar(
            _js_expr=func, _var_type=_var_type, _var_data=_var_data
        )
        return function_var


@dataclasses.dataclass(
    eq=False,
    frozen=True,
    slots=True,
)
class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
    """Base class for immutable vars that are the result of a function call."""

    # The function var being called.
    _func: FunctionVar[ReflexCallable[P, R]] | None = dataclasses.field(default=None)
    # The call arguments; non-Var values are converted with LiteralVar.create
    # when the expression or var data is computed.
    _args: tuple[Var | Any, ...] = dataclasses.field(default_factory=tuple)

    @cached_property_no_lock
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The call expression ``(func(arg1, arg2, ...))`` with every
            argument rendered through ``LiteralVar.create``.
        """
        return f"({self._func!s}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"

    @cached_property_no_lock
    def _cached_get_all_var_data(self) -> VarData | None:
        """Get all the var data associated with the var.

        Returns:
            The merge of the function's var data (if a function is set), the
            var data of every argument, and this var's own var data.
        """
        return VarData.merge(
            self._func._get_all_var_data() if self._func is not None else None,
            *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
            self._var_data,
        )

    @classmethod
    def create(
        cls,
        func: FunctionVar[ReflexCallable[P, R]],
        *args: Var | Any,
        _var_type: GenericType = Any,
        _var_data: VarData | None = None,
    ) -> VarOperationCall:
        """Create a new function call var.

        Args:
            func: The function to call.
            *args: The arguments to call the function with.
            _var_type: The type of the Var. When left as ``Any``, the return
                type is taken from the function's callable type if that type
                carries type arguments.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function call var.
        """
        # The return type is the LAST type argument of a callable type. For
        # ``ReflexCallable[P, R]`` that is index 1, but ``Callable[[A, B], R]``
        # flattens its arguments to ``(A, B, R)``, so a fixed index of 1 would
        # pick a parameter type (or fail on a single-element tuple).
        type_args = getattr(func._var_type, "__args__", None)
        function_return_type = type_args[-1] if type_args else Any
        var_type = _var_type if _var_type is not Any else function_return_type
        return cls(
            _js_expr="",
            _var_type=var_type,
            _var_data=_var_data,
            _func=func,
            _args=args,
        )


@dataclasses.dataclass(frozen=True)
class DestructuredArg:
    """Class for destructured arguments.

    Describes a JavaScript object-destructuring parameter such as
    ``{a, b, ...others}``.
    """

    # Names of the properties pulled out of the object.
    fields: tuple[str, ...] = ()
    # Optional name that collects the remaining properties (``...rest``).
    rest: str | None = None

    def to_javascript(self) -> str:
        """Convert the destructured argument to JavaScript.

        Returns:
            The destructured argument in JavaScript, e.g. ``{a, b, ...rest}``.
            When there are no fields but a rest name is given, the result is
            ``{...rest}`` (no leading comma).
        """
        # Collect every piece first and join once, so the separator only ever
        # appears BETWEEN entries; appending ", ...rest" to an empty field
        # list would produce the invalid pattern "{, ...rest}".
        parts = list(self.fields)
        if self.rest:
            parts.append(f"...{self.rest}")
        return format.wrap(", ".join(parts), "{", "}")


@dataclasses.dataclass(frozen=True)
class FunctionArgs:
    """Immutable description of a function's parameter list.

    Attributes:
        args: The positional parameters, each either a plain name or a
            destructured argument.
        rest: The name of the rest parameter, or None when there is none.
    """

    args: tuple[str | DestructuredArg, ...] = ()
    rest: str | None = None


def format_args_function_operation(
    args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
) -> str:
    """Format an args function operation.

    Renders a JavaScript arrow function of the form
    ``((arg1, arg2, ...rest) => body)``.

    Args:
        args: The function arguments.
        return_expr: The return expression.
        explicit_return: Whether to use explicit return syntax.

    Returns:
        The formatted args function operation.
    """
    # Build the full parameter list before joining so the separator is only
    # placed between entries. Appending ", ...rest" to an empty list of
    # positional names would yield "(, ...rest)", which is invalid JavaScript.
    params = [
        arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args
    ]
    if args.rest:
        params.append(f"...{args.rest}")
    arg_names_str = ", ".join(params)

    return_expr_str = str(LiteralVar.create(return_expr))

    # Wrap return expression in curly braces if explicit return syntax is used.
    return_expr_str_wrapped = (
        format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
    )

    return f"(({arg_names_str}) => {return_expr_str_wrapped})"


@dataclasses.dataclass(eq=False, frozen=True, slots=True)
class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
    """Immutable function var described by its arguments and a return expression."""

    # Parameter list of the function.
    _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
    # Expression the function evaluates to.
    _return_expr: Var | Any = dataclasses.field(default=None)
    # Whether the body uses explicit return syntax.
    _explicit_return: bool = dataclasses.field(default=False)

    @cached_property_no_lock
    def _cached_var_name(self) -> str:
        """Render the function.

        Returns:
            The string produced by ``format_args_function_operation`` for this
            var's arguments, return expression and explicit-return flag.
        """
        return format_args_function_operation(
            args=self._args,
            return_expr=self._return_expr,
            explicit_return=self._explicit_return,
        )

    @classmethod
    def create(
        cls,
        args_names: Sequence[str | DestructuredArg],
        return_expr: Var | Any,
        rest: str | None = None,
        explicit_return: bool = False,
        _var_type: GenericType = Callable,
        _var_data: VarData | None = None,
    ):
        """Build a function var from argument names and a return expression.

        Args:
            args_names: The names of the arguments.
            return_expr: The return expression of the function.
            rest: The name of the rest argument.
            explicit_return: Whether to use explicit return syntax.
            _var_type: The type of the Var.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function var.
        """
        # Convert the return expression with Var.create before storing it.
        wrapped_return = Var.create(return_expr)
        function_args = FunctionArgs(args=tuple(args_names), rest=rest)
        return cls(
            _js_expr="",
            _var_type=_var_type,
            _var_data=_var_data,
            _args=function_args,
            _return_expr=wrapped_return,
            _explicit_return=explicit_return,
        )


@dataclasses.dataclass(eq=False, frozen=True, slots=True)
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
    """Builder-pattern function var described by its arguments and a return expression."""

    # Parameter list of the function.
    _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
    # Expression the function evaluates to.
    _return_expr: Var | Any = dataclasses.field(default=None)
    # Whether the body uses explicit return syntax.
    _explicit_return: bool = dataclasses.field(default=False)

    @cached_property_no_lock
    def _cached_var_name(self) -> str:
        """Render the function.

        Returns:
            The string produced by ``format_args_function_operation`` for this
            var's arguments, return expression and explicit-return flag.
        """
        return format_args_function_operation(
            args=self._args,
            return_expr=self._return_expr,
            explicit_return=self._explicit_return,
        )

    @classmethod
    def create(
        cls,
        args_names: Sequence[str | DestructuredArg],
        return_expr: Var | Any,
        rest: str | None = None,
        explicit_return: bool = False,
        _var_type: GenericType = Callable,
        _var_data: VarData | None = None,
    ):
        """Build a function var from argument names and a return expression.

        Args:
            args_names: The names of the arguments.
            return_expr: The return expression of the function.
            rest: The name of the rest argument.
            explicit_return: Whether to use explicit return syntax.
            _var_type: The type of the Var.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function var.
        """
        # Convert the return expression with Var.create before storing it.
        wrapped_return = Var.create(return_expr)
        function_args = FunctionArgs(args=tuple(args_names), rest=rest)
        return cls(
            _js_expr="",
            _var_type=_var_type,
            _var_data=_var_data,
            _args=function_args,
            _return_expr=wrapped_return,
            _explicit_return=explicit_return,
        )


# Ready-made function vars, each typed as taking one argument of any type.

# JavaScript's JSON.stringify, typed as returning str.
JSON_STRINGIFY = FunctionStringVar.create(
    "JSON.stringify", _var_type=ReflexCallable[[Any], str]
)
# JavaScript's Array.isArray, typed as returning bool.
ARRAY_ISARRAY = FunctionStringVar.create(
    "Array.isArray", _var_type=ReflexCallable[[Any], bool]
)
# Arrow function that calls .toString() on its argument, typed as returning str.
PROTOTYPE_TO_STRING = FunctionStringVar.create(
    "((__to_string) => __to_string.toString())",
    _var_type=ReflexCallable[[Any], str],
)
